{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "hoQQiZDB6URn"
   },
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "3vhAMaIOBIee"
   },
   "outputs": [],
   "source": [
    "from __future__ import absolute_import, division, print_function, unicode_literals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import IPython.display as display\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import os\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' \n",
    "# 0 = all messages are logged (default behavior)\n",
    "# 1 = INFO messages are not printed\n",
    "# 2 = INFO and WARNING messages are not printed\n",
    "# 3 = INFO, WARNING, and ERROR messages are not printed\n",
    "\n",
    "#On Mac you may encounter an error related to OMP, this is a workaround, but slows down the code\n",
    "os.environ['KMP_DUPLICATE_LIB_OK']='True' #https://github.com/dmlc/xgboost/issues/1715"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "QGXxBuPyKJw1"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import utils\n",
    "import dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "KT6CcaqgQewg"
   },
   "outputs": [],
   "source": [
    "AUTOTUNE = tf.data.experimental.AUTOTUNE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ZJ20R66fzktl"
   },
   "outputs": [],
   "source": [
    "tf.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data processing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "lAmtzsnjDNhB"
   },
   "source": [
    "Define the dataset directory and give it a name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "rN-Pc6Zd6awg"
   },
   "outputs": [],
   "source": [
    "base_dir = \"dataset\"\n",
    "dataset_name = \"my_openbot\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default, all datasets in the training data directory will be used. You can also specify only specific datasets, for example:\n",
    "```\n",
    "train_datasets = [\"my_openbot_1\", \"my_openbot_2\"]\n",
    "test_datasets = [\"my_openbot_3\"]\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data_dir = os.path.join(base_dir,\"train_data\")\n",
    "test_data_dir = os.path.join(base_dir, \"test_data\")\n",
    "\n",
    "train_datasets = [d for d in os.listdir(train_data_dir) if os.path.isdir(os.path.join(train_data_dir, d))] \n",
    "test_datasets = [d for d in os.listdir(test_data_dir) if os.path.isdir(os.path.join(test_data_dir, d))] "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Train Datasets: ',len(train_datasets))\n",
    "print('Test Datasets: ',len(test_datasets))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Running this for the first time will take a while. This code matches image frames to the controls (labels) and indicator signals (commands). By default, data samples where the vehicle was stationary are removed. If this is not desired, pass `remove_zeros=False`. If you have changed the sensor files, changed `remove_zeros`, or moved your dataset to a new directory, pass `redo_matching=True`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import associate_frames\n",
    "max_offset = 1e3 #1ms\n",
    "train_frames = associate_frames.match_frame_ctrl_cmd(train_data_dir, \n",
    "                                                     train_datasets, \n",
    "                                                     max_offset, \n",
    "                                                     redo_matching=False, \n",
    "                                                     remove_zeros=True)\n",
    "test_frames = associate_frames.match_frame_ctrl_cmd(test_data_dir, \n",
    "                                                    test_datasets, \n",
    "                                                    max_offset, \n",
    "                                                    redo_matching=False, \n",
    "                                                    remove_zeros=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_count_train = len(train_frames) \n",
    "image_count_test = len(test_frames) \n",
    "print(\"There are %d train images and %d test images\" %(image_count_train, image_count_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "1zf695or-Flq"
   },
   "source": [
    "You may have to tune the learning rate and batch size depending on your available compute resources and dataset. As a general rule of thumb, if you increase the batch size by a factor of n, you can increase the learning rate by a factor of sqrt(n). For debugging and hyperparameter tuning, you can set the number of epochs to a small value like 10. If you want to train a model which will achieve good performance, you should set it to 50 or more. In our paper we used 100."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TRAIN_BATCH_SIZE = 16 #128\n",
    "TEST_BATCH_SIZE = 16 #128\n",
    "LR = 0.0001 #0.0003\n",
    "NUM_EPOCHS = 10 #100"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Don't change these unless you know what you are doing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Input dimensions\n",
    "IMG_HEIGHT = 720\n",
    "IMG_WIDTH = 1280\n",
    "#Offset dimensions (crop): top-left corner of the crop window used by decode_resize_img\n",
    "OFFSET_IMG_HEIGHT = 240\n",
    "OFFSET_IMG_WIDTH = 0\n",
    "#Target dimensions: size of the crop window\n",
    "TARGET_IMG_HEIGHT = IMG_HEIGHT-OFFSET_IMG_HEIGHT\n",
    "TARGET_IMG_WIDTH = IMG_WIDTH-OFFSET_IMG_WIDTH\n",
    "#Network dimensions: the crop window downscaled by a factor of 5\n",
    "NETWORK_IMG_HEIGHT = TARGET_IMG_HEIGHT//5\n",
    "NETWORK_IMG_WIDTH = TARGET_IMG_WIDTH//5\n",
    "\n",
    "#Batches per epoch, rounded up so a final partial batch is counted.\n",
    "#np.ceil returns a NumPy float, but Keras expects an integer step count, hence int().\n",
    "STEPS_PER_EPOCH = int(np.ceil(image_count_train/TRAIN_BATCH_SIZE))\n",
    "\n",
    "#BN is passed to the model constructor and, like the two augmentation flags, is appended to the model name\n",
    "BN = True\n",
    "#Enable flip_sample / augment_cmd in process_train_path\n",
    "FLIP_AUG = False\n",
    "CMD_AUG = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "AxS1cLzM8mEp"
   },
   "source": [
    "## Load using `tf.data`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Ylj9fgkamgWZ"
   },
   "source": [
    "Inspired by:\n",
    "https://github.com/tensorflow/docs/blob/master/site/en/tutorials/load_data/images.ipynb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "IIG5CPaULegg"
   },
   "source": [
    "To load the files as a `tf.data.Dataset` first create a dataset of the file paths. Depending on dataset size, this may take some time. If you encounter issues, you can use the commented lines instead. However, this will take **much** longer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# list_train_ds = tf.data.Dataset.list_files(train_frames)\n",
    "# list_test_ds = tf.data.Dataset.list_files(test_frames)\n",
    "\n",
    "list_train_ds = tf.data.Dataset.list_files([str(train_data_dir+'/'+dataset+'/*/images/*') \n",
    "                                            for dataset in train_datasets])\n",
    "list_test_ds = tf.data.Dataset.list_files([str(test_data_dir+'/'+dataset+'/*/images/*') \n",
    "                                           for dataset in test_datasets])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Random dataset samples for verification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "coORvEH-NGwc"
   },
   "outputs": [],
   "source": [
    "for f in list_train_ds.take(5):\n",
    "  print(f.numpy())\n",
    "print()\n",
    "for f in list_test_ds.take(5):\n",
    "  print(f.numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = dataloader.dataloader(train_data_dir, train_datasets)\n",
    "print (\"Number of train samples: %d\" %len(train_data.labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_data = dataloader.dataloader(test_data_dir, test_datasets)\n",
    "print (\"Number of test samples: %d\" %len(test_data.labels))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Some functions for augmentation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def augment_img(img):\n",
    "  \"\"\"Apply random color jitter to an image.\n",
    "\n",
    "  Hue, saturation, brightness and contrast are perturbed, in that order.\n",
    "\n",
    "  Args:\n",
    "    img: input image\n",
    "\n",
    "  Returns:\n",
    "    The color-augmented image\n",
    "  \"\"\"\n",
    "  jittered = tf.image.random_hue(img, 0.08)\n",
    "  jittered = tf.image.random_saturation(jittered, 0.6, 1.6)\n",
    "  jittered = tf.image.random_brightness(jittered, 0.05)\n",
    "  return tf.image.random_contrast(jittered, 0.7, 1.3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def augment_cmd(cmd):\n",
    "  \"\"\"\n",
    "  Command augmentation\n",
    "\n",
    "  A command that is neither positive nor negative (i.e. zero) is randomly\n",
    "  replaced: with probability 0.25 by -1.0, with probability 0.25 by 1.0,\n",
    "  and left unchanged otherwise. Non-zero commands are returned as-is.\n",
    "\n",
    "  Args:\n",
    "    cmd: input command\n",
    "\n",
    "  Returns:\n",
    "    cmd: augmented command\n",
    "  \"\"\"\n",
    "  # true only when cmd is neither > 0 nor < 0\n",
    "  if not (cmd > 0 or cmd < 0):\n",
    "    # one uniform sample between 0 and 1 decides which replacement (if any) is used\n",
    "    coin = tf.random.uniform(shape=[1],minval=0, maxval=1, dtype=tf.dtypes.float32)\n",
    "    if (coin < 0.25):\n",
    "      cmd = -1.0\n",
    "    elif coin < 0.5:\n",
    "      cmd = 1.0\n",
    "  return cmd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def flip_sample(img,cmd,label):\n",
    "  \"\"\"\n",
    "  Random horizontal flip of a complete sample\n",
    "\n",
    "  With probability 0.5 the image is mirrored left-right, the command is\n",
    "  negated and the label is reversed along its first axis; otherwise the\n",
    "  sample is returned unchanged.\n",
    "\n",
    "  Args:\n",
    "    img: input image\n",
    "    cmd: input command\n",
    "    label: input label\n",
    "\n",
    "  Returns:\n",
    "    img, cmd, label: the (possibly flipped) sample\n",
    "  \"\"\"\n",
    "  coin = tf.random.uniform(shape=[1],minval=0, maxval=1, dtype=tf.dtypes.float32)\n",
    "  if coin < 0.5:\n",
    "    img = tf.image.flip_left_right(img)\n",
    "    cmd = -cmd\n",
    "    # NOTE(review): presumably the label holds (left, right) controls, so reversing swaps them - confirm against dataloader\n",
    "    label = tf.reverse(label, axis=[0])\n",
    "  return img,cmd,label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "MGlq4IP4Aktb"
   },
   "outputs": [],
   "source": [
    "def decode_resize_img(img):\n",
    "  \"\"\"Decode an image, crop it and scale it to the network input size.\n",
    "\n",
    "  Args:\n",
    "    img: encoded image contents, as accepted by `decode_img`\n",
    "\n",
    "  Returns:\n",
    "    The cropped image resized to [NETWORK_IMG_HEIGHT, NETWORK_IMG_WIDTH]\n",
    "  \"\"\"\n",
    "  decoded = decode_img(img)\n",
    "  # keep the TARGET_IMG_HEIGHT x TARGET_IMG_WIDTH window whose top-left\n",
    "  # corner is at (OFFSET_IMG_HEIGHT, OFFSET_IMG_WIDTH)\n",
    "  cropped = tf.image.crop_to_bounding_box(\n",
    "    decoded,\n",
    "    OFFSET_IMG_HEIGHT, OFFSET_IMG_WIDTH,\n",
    "    TARGET_IMG_HEIGHT, TARGET_IMG_WIDTH)\n",
    "  # scale the crop down to the resolution the network expects\n",
    "  return tf.image.resize(cropped, [NETWORK_IMG_HEIGHT, NETWORK_IMG_WIDTH])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode_img(img):\n",
    "  \"\"\"Decode JPEG contents into a 3-channel float32 image.\n",
    "\n",
    "  Args:\n",
    "    img: the compressed JPEG string\n",
    "\n",
    "  Returns:\n",
    "    The decoded image as float32 values in the [0,1] range\n",
    "  \"\"\"\n",
    "  # compressed string -> 3D uint8 tensor\n",
    "  decoded_uint8 = tf.image.decode_jpeg(img, channels=3)\n",
    "  # `convert_image_dtype` rescales to floats in the [0,1] range\n",
    "  return tf.image.convert_image_dtype(decoded_uint8, tf.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "91CPfUUJ_8SZ"
   },
   "source": [
    "Short pure-TensorFlow functions that convert a file path to an `((image, command), label)` sample:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "-xhBRgvNqRRe"
   },
   "outputs": [],
   "source": [
    "def process_train_path(file_path):\n",
    "  \"\"\"Turn a training image path into an ((image, command), label) sample.\n",
    "\n",
    "  Color augmentation is always applied; the flip and command augmentations\n",
    "  are applied only when FLIP_AUG / CMD_AUG are set.\n",
    "\n",
    "  Args:\n",
    "    file_path: path of the image file\n",
    "\n",
    "  Returns:\n",
    "    (img, cmd), label\n",
    "  \"\"\"\n",
    "  # the label lookup uses forward slashes, whatever separator the path came with\n",
    "  lookup_path = tf.strings.regex_replace(file_path,\"[/\\\\\\\\]\",\"/\")\n",
    "  cmd, label = train_data.get_label(lookup_path)\n",
    "  # read the raw file contents, decode them, then jitter the colors\n",
    "  raw_bytes = tf.io.read_file(file_path)\n",
    "  img = augment_img(decode_img(raw_bytes))\n",
    "  if FLIP_AUG:\n",
    "    img, cmd, label = flip_sample(img, cmd, label)\n",
    "  if CMD_AUG:\n",
    "    cmd = augment_cmd(cmd)\n",
    "  return (img, cmd), label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_test_path(file_path):\n",
    "  \"\"\"Turn a test image path into an ((image, command), label) sample.\n",
    "\n",
    "  No augmentation is applied to test samples.\n",
    "\n",
    "  Args:\n",
    "    file_path: path of the image file\n",
    "\n",
    "  Returns:\n",
    "    (img, cmd), label\n",
    "  \"\"\"\n",
    "  # the label lookup uses forward slashes, whatever separator the path came with\n",
    "  lookup_path = tf.strings.regex_replace(file_path,\"[/\\\\\\\\]\",\"/\")\n",
    "  cmd, label = test_data.get_label(lookup_path)\n",
    "  # read the raw file contents and decode them\n",
    "  raw_bytes = tf.io.read_file(file_path)\n",
    "  return (decode_img(raw_bytes), cmd), label"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "S9a5GpsUOBx8"
   },
   "source": [
    "Use `Dataset.map` to create a dataset of `(image, command), label` samples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "3SDhbo8lOBQv"
   },
   "outputs": [],
   "source": [
    "# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.\n",
    "labeled_ds = list_train_ds.map(process_train_path, num_parallel_calls=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "kxrl0lGdnpRz"
   },
   "outputs": [],
   "source": [
    "for (image, cmd), label in labeled_ds.take(1):\n",
    "  print(\"Image shape: \", image.numpy().shape)\n",
    "  print(\"Command: \", cmd.numpy())\n",
    "  print(\"Label: \", label.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "-YKnrfAeZV10"
   },
   "outputs": [],
   "source": [
    "train_ds = utils.prepare_for_training(ds=labeled_ds, \n",
    "                                      batch_sz=TRAIN_BATCH_SIZE, \n",
    "                                      shuffle_buffer_sz=100*TRAIN_BATCH_SIZE, \n",
    "                                      prefetch_buffer_sz=10*TRAIN_BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "UN_Dnl72YNIj",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "(image_batch, cmd_batch), label_batch = next(iter(train_ds))\n",
    "utils.show_train_batch(image_batch.numpy(), cmd_batch.numpy(), label_batch.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test pipeline: no shuffling, repetition or augmentation\n",
    "test_ds = list_test_ds.map(process_test_path, num_parallel_calls=4)\n",
    "test_ds = test_ds.batch(TEST_BATCH_SIZE)\n",
    "# Size the prefetch buffer from the test batch size (it was tied to TRAIN_BATCH_SIZE)\n",
    "test_ds = test_ds.prefetch(buffer_size=10*TEST_BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import models\n",
    "import losses\n",
    "import metrics\n",
    "import callbacks\n",
    "\n",
    "# Build the network for the downscaled input size; BN is forwarded to the model constructor\n",
    "model = models.cil_mobile(NETWORK_IMG_WIDTH,NETWORK_IMG_HEIGHT,BN)\n",
    "loss_fn = losses.sq_weighted_mse_angle\n",
    "metric_list = ['MeanAbsoluteError', metrics.direction_metric, metrics.angle_metric]\n",
    "# `learning_rate` is the documented argument name; `lr` is a deprecated alias\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=LR)\n",
    "\n",
    "model.compile(optimizer=optimizer,\n",
    "          loss=loss_fn,\n",
    "          metrics=metric_list)\n",
    "# summary() prints the table itself and returns None, so wrapping it in print() adds a stray \"None\"\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_NAME = dataset_name + \"_\" + model.name + \"_lr\" + str(LR) + \"_bz\" + str(TRAIN_BATCH_SIZE)\n",
    "if BN:\n",
    "    MODEL_NAME += \"_bn\"\n",
    "if FLIP_AUG:\n",
    "    MODEL_NAME += \"_flip\"\n",
    "if CMD_AUG:\n",
    "    MODEL_NAME += \"_cmd\"    \n",
    "    \n",
    "checkpoint_path = os.path.join('models', MODEL_NAME, 'checkpoints')\n",
    "log_path = os.path.join('models',MODEL_NAME,'logs')\n",
    "print(MODEL_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "history = model.fit(train_ds, \n",
    "                    epochs=NUM_EPOCHS, \n",
    "                    steps_per_epoch=STEPS_PER_EPOCH, \n",
    "                    validation_data=test_ds, \n",
    "                    callbacks=[callbacks.checkpoint_cb(checkpoint_path),\n",
    "                               callbacks.tensorboard_cb(log_path),\n",
    "                               callbacks.logger_cb(log_path)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plot metrics and loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(history.history['MeanAbsoluteError'], label='mean_absolute_error')\n",
    "plt.plot(history.history['val_MeanAbsoluteError'], label = 'val_mean_absolute_error')\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Mean Absolute Error')\n",
    "plt.legend(loc='lower right')\n",
    "plt.savefig(os.path.join(log_path,'error.png'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(history.history['direction_metric'], label='direction_metric')\n",
    "plt.plot(history.history['val_direction_metric'], label = 'val_direction_metric')\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Direction Metric')\n",
    "plt.legend(loc='lower right')\n",
    "plt.savefig(os.path.join(log_path,'direction.png'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(history.history['angle_metric'], label='angle_metric')\n",
    "plt.plot(history.history['val_angle_metric'], label = 'val_angle_metric')\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Angle Metric')\n",
    "plt.legend(loc='lower right')\n",
    "plt.savefig(os.path.join(log_path,'angle.png'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "plt.plot(history.history['loss'], label='loss')\n",
    "plt.plot(history.history['val_loss'], label = 'val_loss')\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Loss')\n",
    "plt.legend(loc='lower right')\n",
    "plt.savefig(os.path.join(log_path,'loss.png'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save TFLite models for the best and the last checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The best epoch is the one with the highest sum of validation angle and direction metrics\n",
    "best_index = np.argmax(np.array(history.history['val_angle_metric']) \\\n",
    "                     + np.array(history.history['val_direction_metric']))\n",
    "# history is 0-indexed; the +1 presumably matches 1-based epoch numbers in the\n",
    "# checkpoint names written by callbacks.checkpoint_cb - TODO confirm\n",
    "best_checkpoint = str(\"cp-%04d.ckpt\" % (best_index+1))\n",
    "# Convert that checkpoint to TFLite and store it under the name \"best\"\n",
    "best_tflite = utils.generate_tflite(checkpoint_path, best_checkpoint)\n",
    "utils.save_tflite (best_tflite, checkpoint_path, \"best\")\n",
    "print(\"Best Checkpoint (val_angle: %s, val_direction: %s): %s\" \\\n",
    "      %(history.history['val_angle_metric'][best_index],\\\n",
    "        history.history['val_direction_metric'][best_index],\\\n",
    "        best_checkpoint))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "last_checkpoint = sorted([d for d in os.listdir(checkpoint_path) if os.path.isdir(os.path.join(checkpoint_path, d))])[-1]\n",
    "last_tflite = utils.generate_tflite(checkpoint_path, last_checkpoint)\n",
    "utils.save_tflite (last_tflite, checkpoint_path, \"last\")\n",
    "print(\"Last Checkpoint (val_angle: %s, val_direction: %s): %s\" \\\n",
    "      %(history.history['val_angle_metric'][-1], \\\n",
    "        history.history['val_direction_metric'][-1], \\\n",
    "        last_checkpoint))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Evaluate the best model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the best checkpoint with the same loss and metrics used for training\n",
    "best_model = utils.load_model(os.path.join(checkpoint_path,best_checkpoint),loss_fn,metric_list)\n",
    "# `steps` must be an integer; round up so the final partial batch is evaluated as well\n",
    "test_steps = int(np.ceil(image_count_test/TEST_BATCH_SIZE))\n",
    "test_loss, test_acc, test_dir, test_ang = best_model.evaluate(test_ds, steps=test_steps, verbose=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_SAMPLES = 15\n",
    "(image_batch, cmd_batch), label_batch = next(iter(test_ds))\n",
    "pred_batch = best_model.predict( (tf.slice(image_batch, [0, 0, 0, 0], [NUM_SAMPLES, -1, -1, -1]), tf.slice(cmd_batch, [0], [NUM_SAMPLES])) )\n",
    "utils.show_test_batch(image_batch.numpy(), cmd_batch.numpy(), label_batch.numpy(), pred_batch)   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "utils.compare_tf_tflite(best_model,best_tflite)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save the notebook as HTML"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#time.sleep(30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "utils.save_notebook()\n",
    "current_file = 'policy_learning.ipynb'\n",
    "output_file = os.path.join(log_path,'notebook.html')\n",
    "utils.output_HTML(current_file, output_file)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "last_runtime": {
    "build_target": "",
    "kind": "local"
   },
   "name": "images.ipynb",
   "private_outputs": true,
   "provenance": [],
   "toc_visible": true,
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
